using Tables, Plots, CSV, DataFrames, MAT, Distributions, LinearAlgebra, SparseArrays, JLD2, HDF5, Random, StatsFuns
include("setup.jl")
include("getBounds.jl")
include("mcmcfunctions.jl")
include("counterfactualFunctionsNew.jl")
include("modelfit.jl")
include("tableHelperFunctions.jl")
BLAS.set_num_threads(1)

"""
    runSimulation(theta, dataset, counterfactuals, lv, temps, year)

Run the two simulation passes for `year` through `simulateEverything!`. The passes are
tagged `Symbol("simulate<year>a")` and `Symbol("simulate<year>b")` and run in that order.
Both use parameter draw `theta`, latent variables `lv` and scratch space `temps`.

Steps before simulating:
- Every row of `cf_temps.alphadraws` except the last is refilled with fresh `rand()` draws.
- Two capacity vectors are built, each with one extra trailing entry:
  - `cap0`: `seats + sobrecupo`, trailing `0` (baseline / initial match).
  - `cap0_after`: `seats`, trailing `size(lv.U, 2)`.
- For `year < 2012`, the G8 entries of `cap0_after` are overwritten. Each becomes the
  number of entries of `cc.enroll` equal to that G8 position.

Returns `(df_means, df_individuals, df_byIter)`. These containers are handed to
`simulateEverything!` as `dfs`, which presumably fills them in place.
"""
function runSimulation(theta,dataset,counterfactuals,lv,temps,year)
    cf_constants, cf_temps = counterfactuals[year]

    # Independent uniform redraws for all rows of alphadraws but the last.
    cf_temps.alphadraws[1:end-1, :] .= rand.()

    cc = dataset[year][1]
    n_individuals = size(lv.U, 2)

    # Baseline (initial match) capacities: regular seats plus sobrecupo, trailing 0.
    cap0 = vcat(cf_constants.seats + cf_constants.sobrecupo, 0)

    # Post-match capacities: regular seats, trailing entry equal to n_individuals.
    # NOTE(original author): G8 seats are zero in `seats` ("Fix!"); only years before
    # 2012 patch them, using observed enrollment counts.
    cap0_after = vcat(cf_constants.seats, n_individuals)
    if year < 2012
        g8_mask = vcat(cf_constants.G8, false)
        g8_enrolled = [sum(cc.enroll .== g) for g in findall(cf_constants.G8)]
        cap0_after[g8_mask] .= g8_enrolled
    end

    df_individuals = DataFrame(ind=1:n_individuals, type=cc.typ)
    df_means = DataFrame(inds=["All","Male Private","Male Public","Female Private","Female Public"])
    df_byIter = Dict{Symbol,Array{Float64,1}}()
    dfs = (df_byIter, df_means, df_individuals)

    # Pass "a" first, then pass "b", both writing into the same containers.
    for suffix in ("a", "b")
        tag = Symbol("simulate$(year)$(suffix)")
        simulateEverything!(dfs, tag, theta, cc, lv, temps, cf_constants, cf_temps,
                            cap0, cap0_after)
    end

    return df_means, df_individuals, df_byIter
end

# Machine-specific locations. Each one can be overridden with an environment variable,
# so the script runs on another machine without editing this file. The defaults are
# the original hard-coded paths. An alternative data location used previously is
# "G:/My Drive/4_clean".
const datapath = get(ENV, "PLATFORM_DATAPATH",
    "G:/.shortcut-targets-by-id/19Q7cGujagvIhQ_OhZwxSp2KsSfSxPYpS/4_clean")
const figurepath = get(ENV, "PLATFORM_FIGUREPATH",
    "C:/Users/akapor/Documents/GitHub/Platform-Externalities/5_paper/Round 2/Paper/figures/mcmc")
const tablepath = get(ENV, "PLATFORM_TABLEPATH",
    "C:/Users/akapor/Documents/GitHub/Platform-Externalities/5_paper/Round 2/Paper/tables")
# Output directory of the MCMC chain. It is built with `*`, not `joinpath`, so the
# forward-slash form is kept on every OS.
const outputpath = datapath*"/from_Julia/chain2"

const model = ProgramFE()

# MCMC iteration range (earlier runs used 0:5000).
iter = 0:7500
# Iterations at which parameter draws (`save_params`) and latent variables (`save_lv`)
# are looked up. Presumably these are the iterations that were saved during
# estimation; `save_lv` should be a subset of `save_params`.
save_params = 2500:50:7500
save_lv = 2500:200:7500
const _sample_fraction = 1.0

# Stage switches for this script and the files it includes.
const getROLs = false
const saveROLs = false
const computeSimulatedOutcomes = false
const describeSimulatedOutcomes = true
const doExtraCounterfactuals = false
const compare_spda_cpda = false

# Year-keyed containers. Presumably they are filled by the per-year include below.
dataset = Dict{Int,Any}()
counterfactuals = Dict{Int,Any}()

# Run the per-year script for 2012, 2011 and 2010, in that order. `year` is assigned
# with `global`: inside a top-level loop a plain assignment would be loop-local and
# invisible to the included file. After the loop the global `year` is 2010.
for y in (2012, 2011, 2010)
    global year = y
    include("_simulateOneYear_inner.jl")
end

# Produce the model-fit description only when the `describeSimulatedOutcomes` switch is on.
describeSimulatedOutcomes && include("describeModelFit.jl")

# Optional stage. For each saved latent-variable draw of 2012, run
# `runExtraCounterfactuals` with the matching parameter draw. The per-individual
# results are tagged with a draw number, stacked, and written to CSV.
if doExtraCounterfactuals
    # An `if` opens no scope, so this sets the global `year` once. The original
    # assigned it inside the `for` loop. In a non-interactive run that made a
    # loop-local `year` (soft-scope ambiguity warning); in the REPL it overwrote the
    # global. The result therefore depended on how the script was run.
    year = 2012
    c_individuals_extra = AbstractDataFrame[]
    for (r,t) in enumerate(save_lv)
        println("running more counterfactuals: iter $t")
        JLD2.@load(outputpath*"/latentvariables$(year)_$t.mat",lv_out)
        # `history` is indexed by position in `save_params`. Fail with a clear
        # message instead of `history[nothing]` when t is not present.
        k = findfirst(==(t), save_params)
        k === nothing && error("latent-variable iteration $t has no matching entry in save_params")
        theta_t = history[k]
        # Only the per-individual frame is used. Distinct local names avoid shadowing
        # (script) or clobbering (REPL) global df_means / df_individuals / df_byIter.
        _, df_ind_extra, _ = runExtraCounterfactuals(theta_t,dataset,counterfactuals,lv_out)
        # Draw number = 1-based position of t in save_lv.
        df_ind_extra[!,:draw] = fill(r, nrow(df_ind_extra))
        push!(c_individuals_extra,df_ind_extra)
    end
    # reduce(vcat, ...) uses DataFrames' multi-frame vcat instead of splatting varargs.
    df_extra = reduce(vcat, c_individuals_extra)
    CSV.write(outputpath*"/counterfactuals_additional_big.csv",df_extra)
end
